{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Amazon SageMaker XGBoost Bring Your Own Model\n",
    "_**Hosting a Pre-Trained scikit-learn Model in the Amazon SageMaker XGBoost Algorithm Container**_\n",
    "\n",
    "---\n",
    "\n",
    "## Contents\n",
    "\n",
    "1. [Background](#Background)\n",
    "1. [Setup](#Setup)\n",
    "1. [Optionally, train a scikit learn XGBoost model](#Optionally,-train-a-scikit-learn-XGBoost-model)\n",
    "1. [Upload the pre-trained model to S3](#Upload-the-pre-trained-model-to-S3)\n",
    "1. [Set up hosting for the model](#Set-up-hosting-for-the-model)\n",
    "1. [Validate the model for use](#Validate-the-model-for-use)\n",
    "\n",
    "---\n",
    "## Background\n",
    "\n",
    "Amazon SageMaker provides a hosted notebook environment, distributed, serverless training, and real-time hosting. These three services work best when used together, but each can also be used independently. Some use cases require only hosting, for example when the model was trained in a different service before Amazon SageMaker existed.\n",
    "\n",
    "This notebook shows how to take an existing XGBoost model trained through the scikit-learn interface and use the Amazon SageMaker XGBoost Algorithm container to quickly create a hosted endpoint for it. Note that a scikit-learn XGBoost model is compatible with the SageMaker XGBoost container, whereas other gradient boosted tree models (such as one trained in SparkML) are not.\n",
    "\n",
    "---\n",
    "## Setup\n",
    "\n",
    "Let's start by specifying:\n",
    "\n",
    "* The AWS region.\n",
    "* The IAM role ARN used to give training and hosting access to your data. See the documentation for how to specify these.\n",
    "* The S3 bucket that you want to use for training and model data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "isConfigCell": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "import os\n",
    "import boto3\n",
    "import re\n",
    "import json\n",
    "from sagemaker import get_execution_role\n",
    "\n",
    "region = boto3.Session().region_name\n",
    "\n",
    "role = get_execution_role()\n",
    "\n",
    "bucket = '<s3 bucket>'  # replace with your S3 bucket name; create the bucket first if it does not exist\n",
    "prefix = 'sagemaker/DEMO-xgboost-byo'\n",
    "bucket_path = 'https://s3-{}.amazonaws.com/{}'.format(region,bucket)\n",
    "# customize to your bucket where you have stored the data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Optionally, train a scikit learn XGBoost model\n",
    "\n",
    "These steps are optional. They generate the scikit-learn model that will eventually be hosted using the SageMaker Algorithm container.\n",
    "\n",
    "### Install XGBoost\n",
    "Note that for a conda-based installation, you'll need to change the notebook kernel to the environment with conda and Python 3."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "!conda install -y -c conda-forge xgboost"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fetch the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "import pickle, gzip, urllib.request\n",
    "\n",
    "# Download the MNIST dataset, then load the train/validation/test splits\n",
    "urllib.request.urlretrieve(\"http://deeplearning.net/data/mnist/mnist.pkl.gz\", \"mnist.pkl.gz\")\n",
    "# The context manager closes the file even if unpickling fails\n",
    "with gzip.open('mnist.pkl.gz', 'rb') as f:\n",
    "    train_set, valid_set, test_set = pickle.load(f, encoding='latin1')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare the dataset for training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "import gzip\n",
    "import pickle\n",
    "\n",
    "def get_dataset():\n",
    "    # Load the (train, validation, test) splits from the downloaded MNIST pickle\n",
    "    with gzip.open('mnist.pkl.gz', 'rb') as f:\n",
    "        # encoding='latin1' is needed to read this pickle under Python 3\n",
    "        return pickle.load(f, encoding='latin1')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train_set, valid_set, test_set = get_dataset()\n",
    "\n",
    "train_X = train_set[0]\n",
    "train_y = train_set[1]\n",
    "\n",
    "valid_X = valid_set[0]\n",
    "valid_y = valid_set[1]\n",
    "\n",
    "test_X = test_set[0]\n",
    "test_y = test_set[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train the XGBClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import xgboost as xgb\n",
    "\n",
    "# Set up the XGBoost classifier through its scikit-learn interface\n",
    "bt = xgb.XGBClassifier(max_depth=5,\n",
    "                       learning_rate=0.2,\n",
    "                       n_estimators=10,\n",
    "                       objective='multi:softmax')\n",
    "# Fit on the training set, using the validation set for evaluation\n",
    "bt.fit(train_X, train_y,\n",
    "       eval_set=[(valid_X, valid_y)],\n",
    "       verbose=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save the trained model file\n",
    "Note that the model file name must satisfy the regular expression pattern `^[a-zA-Z0-9](-*[a-zA-Z0-9])*`. The model file also needs to be packaged as a gzipped tar archive (`model.tar.gz`)."
   ]
  },
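  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, a candidate name can be tested against this pattern locally before anything is saved. The cell below is a minimal sketch: `is_valid_name` is a hypothetical helper (not part of any SageMaker library), and anchoring the pattern with `$` so that the whole name must match is an assumption."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "# Pattern from the note above, anchored with $ (an assumption) so the whole name must match\n",
    "NAME_PATTERN = re.compile(r'^[a-zA-Z0-9](-*[a-zA-Z0-9])*$')\n",
    "\n",
    "def is_valid_name(name):\n",
    "    # Hypothetical helper: True if the name uses only letters, digits, and inner hyphens\n",
    "    return NAME_PATTERN.match(name) is not None\n",
    "\n",
    "print(is_valid_name('DEMO-local-xgboost-model'))  # True\n",
    "print(is_valid_name('demo_model'))                # False (underscores are not allowed)"
   ]
  },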
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
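   },
   "outputs": [],
   "source": [
    "model_file_name = \"DEMO-local-xgboost-model\"\n",
    "# Save the underlying Booster via the public accessor rather than the private _Booster attribute\n",
    "bt.get_booster().save_model(model_file_name)"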
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "!tar czvf model.tar.gz $model_file_name"
   ]
  },
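  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before uploading, it can help to confirm what the archive actually contains. The cell below is a minimal sketch using Python's standard `tarfile` module. `list_archive_members` is a hypothetical helper name, and the cell assumes `model.tar.gz` was created by the previous cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import tarfile\n",
    "\n",
    "def list_archive_members(path):\n",
    "    # Hypothetical helper: return the names of all entries in a gzipped tar archive\n",
    "    with tarfile.open(path, 'r:gz') as archive:\n",
    "        return archive.getnames()\n",
    "\n",
    "# Assumes model.tar.gz exists in the working directory\n",
    "print(list_archive_members('model.tar.gz'))"
   ]
  },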
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Upload the pre-trained model to S3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "key = os.path.join(prefix, model_file_name, 'model.tar.gz')\n",
    "# Open the archive in a context manager so the file handle is closed after the upload\n",
    "with open('model.tar.gz', 'rb') as fObj:\n",
    "    boto3.Session().resource('s3').Bucket(bucket).Object(key).upload_fileobj(fObj)"
   ]
  },
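  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To confirm that the upload succeeded, one option is to request the object's metadata from S3. This is a sketch only, assuming `bucket` and `key` are defined as in the cells above and that the role has permission to read the object. The variable name `s3_client` is introduced here for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import boto3\n",
    "\n",
    "# head_object raises a ClientError if the object does not exist or is not readable\n",
    "s3_client = boto3.client('s3')\n",
    "head = s3_client.head_object(Bucket=bucket, Key=key)\n",
    "print('Found s3://{}/{} ({} bytes)'.format(bucket, key, head['ContentLength']))"
   ]
  },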
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up hosting for the model\n",
    "\n",
    "### Import model into hosting\n",
    "This involves creating a SageMaker model from the model file previously uploaded to S3."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from sagemaker.amazon.amazon_estimator import get_image_uri\n",
    "container = get_image_uri(boto3.Session().region_name, 'xgboost')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "from time import gmtime, strftime\n",
    "\n",
    "model_name = model_file_name + strftime(\"%Y-%m-%d-%H-%M-%S\", gmtime())\n",
    "model_url = 'https://s3-{}.amazonaws.com/{}/{}'.format(region, bucket, key)\n",
    "sm_client = boto3.client('sagemaker')\n",
    "\n",
    "print(model_url)\n",
    "\n",
    "primary_container = {\n",
    "    'Image': container,\n",
    "    'ModelDataUrl': model_url,\n",
    "}\n",
    "\n",
    "create_model_response2 = sm_client.create_model(\n",
    "    ModelName = model_name,\n",
    "    ExecutionRoleArn = role,\n",
    "    PrimaryContainer = primary_container)\n",
    "\n",
    "print(create_model_response2['ModelArn'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create endpoint configuration\n",
    "\n",
    "SageMaker supports configuring REST endpoints in hosting with multiple models, e.g. for A/B testing purposes. To support this, you create an endpoint configuration that describes the distribution of traffic across the models, whether split, shadowed, or sampled in some way. In addition, the endpoint configuration describes the instance type required for model deployment."
   ]
  },
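  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the traffic split concrete, the cell below sketches what a `ProductionVariants` list might look like for a two-model A/B test. The model names `model-a` and `model-b`, the variant names, and the 9:1 weights are hypothetical, and this notebook itself deploys only a single variant. Each variant receives a share of traffic equal to its weight divided by the sum of all weights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Hypothetical two-variant configuration; these model names do not exist in this notebook\n",
    "ab_variants = [\n",
    "    {'VariantName': 'VariantA', 'ModelName': 'model-a', 'InstanceType': 'ml.m4.xlarge',\n",
    "     'InitialInstanceCount': 1, 'InitialVariantWeight': 9},\n",
    "    {'VariantName': 'VariantB', 'ModelName': 'model-b', 'InstanceType': 'ml.m4.xlarge',\n",
    "     'InitialInstanceCount': 1, 'InitialVariantWeight': 1},\n",
    "]\n",
    "\n",
    "# Share of traffic per variant = its weight / sum of all weights\n",
    "total_weight = sum(v['InitialVariantWeight'] for v in ab_variants)\n",
    "for v in ab_variants:\n",
    "    print(v['VariantName'], v['InitialVariantWeight'] / total_weight)"
   ]
  },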
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from time import gmtime, strftime\n",
    "\n",
    "endpoint_config_name = 'DEMO-XGBoostEndpointConfig-' + strftime(\"%Y-%m-%d-%H-%M-%S\", gmtime())\n",
    "print(endpoint_config_name)\n",
    "create_endpoint_config_response = sm_client.create_endpoint_config(\n",
    "    EndpointConfigName = endpoint_config_name,\n",
    "    ProductionVariants=[{\n",
    "        'InstanceType':'ml.m4.xlarge',\n",
    "        'InitialInstanceCount':1,\n",
    "        'InitialVariantWeight':1,\n",
    "        'ModelName':model_name,\n",
    "        'VariantName':'AllTraffic'}])\n",
    "\n",
    "print(\"Endpoint Config Arn: \" + create_endpoint_config_response['EndpointConfigArn'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create endpoint\n",
    "Lastly, create the endpoint that serves the model by specifying the name and configuration defined above. The end result is an endpoint that can be validated and incorporated into production applications. This typically takes 9-11 minutes to complete."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "import time\n",
    "\n",
    "endpoint_name = 'DEMO-XGBoostEndpoint-' + strftime(\"%Y-%m-%d-%H-%M-%S\", gmtime())\n",
    "print(endpoint_name)\n",
    "create_endpoint_response = sm_client.create_endpoint(\n",
    "    EndpointName=endpoint_name,\n",
    "    EndpointConfigName=endpoint_config_name)\n",
    "print(create_endpoint_response['EndpointArn'])\n",
    "\n",
    "resp = sm_client.describe_endpoint(EndpointName=endpoint_name)\n",
    "status = resp['EndpointStatus']\n",
    "print(\"Status: \" + status)\n",
    "\n",
    "while status=='Creating':\n",
    "    time.sleep(60)\n",
    "    resp = sm_client.describe_endpoint(EndpointName=endpoint_name)\n",
    "    status = resp['EndpointStatus']\n",
    "    print(\"Status: \" + status)\n",
    "\n",
    "print(\"Arn: \" + resp['EndpointArn'])\n",
    "print(\"Status: \" + status)"
   ]
  },
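  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an alternative to the hand-written polling loop above, boto3 also ships a built-in waiter for endpoint creation. The cell below is a minimal sketch, assuming `sm_client` and `endpoint_name` are defined as in the previous cell. Unlike the loop, the waiter raises an exception if the endpoint fails to reach `InService`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Block until the endpoint is InService; raises botocore.exceptions.WaiterError on failure or timeout\n",
    "waiter = sm_client.get_waiter('endpoint_in_service')\n",
    "waiter.wait(EndpointName=endpoint_name)\n",
    "\n",
    "resp = sm_client.describe_endpoint(EndpointName=endpoint_name)\n",
    "print('Status: ' + resp['EndpointStatus'])"
   ]
  },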
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Validate the model for use\n",
    "Now you can obtain the endpoint from the client library using the result from previous operations and generate classifications from the model using that endpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "runtime_client = boto3.client('sagemaker-runtime')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's generate the prediction for a single data point. We'll pick one from the test data loaded earlier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "point_X = test_X[0]\n",
    "point_X = np.expand_dims(point_X, axis=0)\n",
    "point_y = test_y[0]\n",
    "np.savetxt(\"test_point.csv\", point_X, delimiter=\",\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "import json\n",
    "\n",
    "\n",
    "file_name = 'test_point.csv'  # customize to your test file; 'test_point.csv' is the file written in the previous cell\n",
    "\n",
    "with open(file_name, 'r') as f:\n",
    "    payload = f.read().strip()\n",
    "\n",
    "response = runtime_client.invoke_endpoint(EndpointName=endpoint_name, \n",
    "                                   ContentType='text/csv', \n",
    "                                   Body=payload)\n",
    "result = response['Body'].read().decode('ascii')\n",
    "print('Predicted Class Probabilities: {}.'.format(result))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Post process the output\n",
    "Since the result is a string, let's process it to determine the output class label."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "floatArr = np.array(json.loads(result))\n",
    "predictedLabel = np.argmax(floatArr)\n",
    "print('Predicted Class Label: {}.'.format(predictedLabel))\n",
    "print('Actual Class Label: {}.'.format(point_y))"
   ]
  },
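  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If the optional training section was run in this session, the hosted prediction can also be compared with the local model's prediction for the same data point. This is a sketch under that assumption: it relies on `bt`, `point_X`, and `predictedLabel` from the cells above, and `local_label` is a name introduced here for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Only works if the optional training section was run, so that bt is defined\n",
    "local_label = bt.predict(point_X)[0]\n",
    "print('Local model label: {}.'.format(local_label))\n",
    "print('Endpoint label:    {}.'.format(predictedLabel))"
   ]
  },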
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### (Optional) Delete the Endpoint\n",
    "\n",
    "When you are done with this notebook, run the `delete_endpoint` call in the cell below. This removes the hosted endpoint you created and avoids charges from a stray instance being left running."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "sm_client.delete_endpoint(EndpointName=endpoint_name)"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  },
  "notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.  Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
